{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flax SST-2 Example\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/google/flax/blob/main/examples/sst2/sst2.ipynb\" ><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "Demonstration notebook for\n",
    "https://github.com/google/flax/tree/main/examples/sst2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Before you start:** Select Runtime -> Change runtime type -> GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **Flax Notebook Workflow**:\n",
    "\n",
    "1. Run the entire notebook end-to-end and check out the outputs.\n",
    "   - This will open Python files in the right-hand editor!\n",
    "   - You'll be able to interactively explore metrics in TensorBoard.\n",
    "2. Change `config` and train for different hyperparameters. Check out the\n",
    "   updated TensorBoard plots.\n",
    "3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you\n",
    "   make in the file will automatically appear in the notebook. Some ideas to\n",
    "   get you started:\n",
    "   - Change the model.\n",
    "   - Log some per-batch metrics during training.\n",
    "   - Add new hyperparameters to `configs/default.py` and use them in\n",
    "     `train.py`.\n",
    "4. At any time, feel free to paste code from `train.py` into the notebook\n",
    "   and modify it directly there!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_directory = 'examples/sst2'\n",
    "editor_relpaths = ('configs/default.py', 'train.py', 'models.py')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (If you run this code in Jupyter[lab], then you're already in the\n",
    "#  example directory and nothing needs to be done.)\n",
    "\n",
    "#@markdown **Fetch newest Flax, copy example code**\n",
    "#@markdown\n",
    "#@markdown **If you select no** below, then the files will be stored on the\n",
    "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n",
    "#@markdown be restarted and any changes are lost**.\n",
    "#@markdown\n",
    "#@markdown **If you select yes** below, then you will be asked for your\n",
    "#@markdown credentials to mount your personal Google Drive. In this case, all\n",
    "#@markdown changes you make will be *persisted*, and even if you re-run the\n",
    "#@markdown Colab later on, the files will still be the same (you can of course\n",
    "#@markdown remove directories inside your Drive's `flax/` root if you want to\n",
    "#@markdown manually revert these files).\n",
    "\n",
    "if 'google.colab' in str(get_ipython()):\n",
    "  import os\n",
    "  os.chdir('/content')\n",
    "  # Download Flax repo from GitHub.\n",
    "  if not os.path.isdir('flaxrepo'):\n",
    "    !git clone --depth=1 https://github.com/google/flax flaxrepo\n",
    "  # Copy example files & change directory.\n",
    "  mount_gdrive = 'no' #@param ['yes', 'no']\n",
    "  if mount_gdrive == 'yes':\n",
    "    DISCLAIMER = 'Note: Editing in your Google Drive, changes will persist.'\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/gdrive')\n",
    "    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n",
    "  else:\n",
    "    DISCLAIMER = 'WARNING: Editing in VM - changes lost after reboot!!'\n",
    "    example_root_path = f'/content/{example_directory}'\n",
    "    from IPython import display\n",
    "    display.display(display.HTML(\n",
    "        f'<h1 style=\"color:red;\" class=\"blink\">{DISCLAIMER}</h1>'))\n",
    "  if not os.path.isdir(example_root_path):\n",
    "    os.makedirs(example_root_path)\n",
    "    !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n",
    "  os.chdir(example_root_path)\n",
    "  from google.colab import files\n",
    "  for relpath in editor_relpaths:\n",
    "    path = f'{example_root_path}/{relpath}'\n",
    "    # Prepend the disclaimer banner to the file, then open it in the editor.\n",
    "    with open(path) as f:\n",
    "      s = f.read()\n",
    "    with open(path, 'w') as f:\n",
    "      f.write(\n",
    "          f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n",
    "    files.view(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: In Colab, the cell above changed the working directory.\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install SST-2 dependencies.\n",
    "!pip install -q -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports / Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To use a TPU instead of a GPU, run this cell to set up the Colab TPU runtime.\n",
    "try:\n",
    "  import jax.tools.colab_tpu\n",
    "  jax.tools.colab_tpu.setup_tpu()\n",
    "except KeyError:\n",
    "  print('\\n### NO TPU CONNECTED - USING CPU or GPU ###\\n')\n",
    "  import os\n",
    "  os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from absl import logging\n",
    "import flax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import time\n",
    "logging.set_verbosity(logging.INFO)\n",
    "\n",
    "# Make sure the GPU is for JAX, not for TF.\n",
    "tf.config.experimental.set_visible_devices([], 'GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Local imports from current directory - auto reload.\n",
    "# Any changes you make to train.py will appear automatically.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import train\n",
    "import models\n",
    "import vocabulary\n",
    "import input_pipeline\n",
    "from configs import default as config_lib\n",
    "config = config_lib.get_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Get datasets.\n",
    "# If you get an error, you need to install tensorflow_datasets from GitHub.\n",
    "train_dataset = input_pipeline.TextDataset(split='train')\n",
    "eval_dataset = input_pipeline.TextDataset(split='validation')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a live update during training - use the \"refresh\" button!\n",
    "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n",
    "if 'google.colab' in str(get_ipython()):\n",
    "  %load_ext tensorboard\n",
    "  %tensorboard --logdir=."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "config.num_epochs = 10\n",
    "model_name = 'bilstm'\n",
    "start_time = time.time()\n",
    "optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n",
    "logging.info('Walltime: %f s', time.time() - start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "tags": []
   },
   "outputs": [],
   "source": [
    "if 'google.colab' in str(get_ipython()):\n",
    "  #@markdown You can upload the training results directly to https://tensorboard.dev\n",
    "  #@markdown\n",
    "  #@markdown Note that everybody with the link will be able to see the data.\n",
    "  upload_data = 'yes' #@param ['yes', 'no']\n",
    "  if upload_data == 'yes':\n",
    "    !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/sst2'"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
